import numpy as np
import torch
from torch import nn

from visual_text_merge_test.vision_transformer import vit_base, MultiheadSelfAttention


class Transformer(nn.Module):
    """Stack of ``layers`` self-attention residual blocks.

    Args:
        width: Embedding dimension passed to each block as ``dim``.
        layers: Number of ``MultiheadSelfAttention`` blocks to stack.
        heads: Number of attention heads per block.
        attn_mask: Optional attention mask. NOTE(review): currently stored
            nowhere and never applied — kept in the signature for caller
            compatibility; wire it into the blocks if masking is needed.
    """

    def __init__(
        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor | None = None
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        # BUG FIX: the original comprehension had no loop, so exactly one
        # block was created regardless of `layers`. Repeat `layers` times.
        self.resblocks = nn.Sequential(
            *[
                MultiheadSelfAttention(
                    dim=width,
                    num_heads=heads,
                )
                for _ in range(layers)
            ]
        )


class ModelMerge(nn.Module):
    """Placeholder module for merging models.

    NOTE(review): only the constructor is visible in this chunk; the class
    may define further methods below. The visible ``__init__`` performs no
    setup beyond ``nn.Module`` initialization.
    """

    def __init__(self):
        super().__init__()
